import csv
import os


def get_csv_by_selected_img(origin_csv, target_csv, selected_img_path):
    """
    Copy the rows of ``origin_csv`` whose image exists under ``selected_img_path``.

    The filtered images are already on disk and the CSV for the full image set
    already exists; this derives the CSV for the filtered set. A row is kept
    when the relative image path in its first column (e.g. ``subject0000/1.jpg``;
    Windows-style backslashes are normalized to ``/``) exists under
    ``selected_img_path``. Kept rows are written unchanged.

    args:
    origin_csv: path to the original CSV covering all images
    target_csv: path to the output CSV; kept rows are APPENDED, so an existing
        file is extended rather than overwritten
    selected_img_path: root directory holding the filtered images

    returns:
    The number of rows written to ``target_csv``.
    """
    count = 0
    written = 0
    with open(origin_csv, 'r', newline='', encoding='utf-8') as input_file:
        # Open the output once for the whole run instead of re-opening it for
        # every kept row.
        with open(target_csv, 'a', newline='', encoding='utf-8') as output_file:
            reader = csv.reader(input_file)
            writer = csv.writer(output_file)
            for row in reader:
                # Progress indicator: prints 0, 10000, 20000, ...
                if count % 10000 == 0:
                    print(count)
                count += 1
                # Blank lines parse as [] and carry no image path; skip them
                # rather than failing with IndexError on row[0].
                if not row:
                    continue
                img_ref_path = row[0].replace("\\", "/")    # subject0000/1.jpg
                img_abs_path = os.path.join(selected_img_path, img_ref_path)
                if os.path.exists(img_abs_path):
                    writer.writerow(row)
                    written += 1
    return written


if __name__ == "__main__":
    import argparse

    # The defaults are the original hard-coded ETH dataset locations, so running
    # the script with no arguments behaves as before; each path can be overridden.
    parser = argparse.ArgumentParser(
        description="Keep only the CSV rows whose image exists in the selected-image directory."
    )
    parser.add_argument(
        "--origin-csv",
        default="/home/xian/mzs/mzs_code/Dataset/ETH/eth_all_noselect.csv",
        help="CSV covering all images; the first column is the relative image path.",
    )
    parser.add_argument(
        "--target-csv",
        default="/home/xian/mzs/mzs_code/Dataset/ETH/eth_all_selected.csv",
        help="Output CSV; kept rows are appended to it.",
    )
    parser.add_argument(
        "--selected-img-path",
        default="/home/xian/mzs/mzs_code/Dataset/ETH/result/all_landmark",
        help="Root directory containing the filtered images.",
    )
    args = parser.parse_args()
    get_csv_by_selected_img(args.origin_csv, args.target_csv, args.selected_img_path)


